{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "u4UNB9a6jCkU"
   },
   "source": [
    "# L5-C: Unpacking 2-Bit Weights\n",
    "\n",
    "In this lesson, you will learn how to \"unpack\" the stored low precision \"packed\" weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 30
   },
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "cEu4XvK3NAef"
   },
   "source": [
    "## Unpacking"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Note:** Younes will explain the below code, and walk through each iteration step. You can go through the comprehensive explaination in the markdown below after first watching Younes's explaination.\n",
    "\n",
    "```Python\n",
    "# Example Tensor: [10110001]\n",
    "    # Which was Originally: 1 0 3 2 - 01 00 11 10\n",
    "\n",
    "    # Starting point of unpacked Tensor\n",
    "    # [00000000 00000000 00000000 00000000]\n",
    "    \n",
    "    ##### First Iteration Start:\n",
    "    # packed int8 Tensor: [10110001]\n",
    "    # You want to extract 01 from [101100 01]\n",
    "    # No right shifts in the First Iteration\n",
    "    # After bit-wise OR operation between 00000000 and 10110001:\n",
    "    # [10110001 00000000 00000000 00000000]\n",
    "    # unpacked Tensor state: [10110001 00000000 00000000 00000000]\n",
    "    ##### First Iteration End\n",
    "\n",
    "    ##### Second Iteration Start:\n",
    "    # packed int8 Tensor: [10110001]\n",
    "    # You want to extract 00 from [1011 00 01]\n",
    "    # 2 right shifts:\n",
    "    # [10110001] (1 shift)-> 01011000 (2 shift)-> 00101100\n",
    "    # After bit-wise OR operation between 00000000 and 00101100:\n",
    "    # [10110001 00101100 00000000 00000000]\n",
    "    # unpacked Tensor state: [10110001 00101100 00000000 00000000]\n",
    "    ##### Second Iteration End\n",
    "\n",
    "    ##### Third Iteration Start:\n",
    "    # packed int8 Tensor: [10110001]\n",
    "    # You want to extract 11 from [10 11 0001]\n",
    "    # 4 right shifts:\n",
    "    # [10110001] (1 shift)-> 01011000 (2 shift)-> 00101100\n",
    "    # 00101100 (3 shift)-> 00010110 (4 shift)-> 00001011\n",
    "    # After bit-wise OR operation between 00000000 and 00001011:\n",
    "    # [10110001 00101100 00001011 00000000]\n",
    "    # unpacked Tensor state: [10110001 00101100 00001011 00000000]\n",
    "    ##### Third Iteration End\n",
    "\n",
    "    ##### Fourth Iteration Start:\n",
    "    # packed int8 Tensor: [10110001]\n",
    "    # You want to extract 10 from [10 110001]\n",
    "    # 6 right shifts:\n",
    "    # [10110001] (1 shift)-> 01011000 (2 shift)-> 00101100\n",
    "    # 00101100 (3 shift)-> 00010110 (4 shift)-> 00001011\n",
    "    # 00001011 (5 shift)-> 00000101 (6 shift)-> 00000010\n",
    "    # After bit-wise OR operation between 00000000 and 00000010:\n",
    "    # [10110001 00101100 00001011 00000010]\n",
    "    # unpacked Tensor state: [10110001 00101100 00001011 00000010]\n",
    "    ##### Fourth Iteration End\n",
    "    \n",
    "    # Last step: Perform masking (bit-wise AND operation)\n",
    "    # Mask: 00000011\n",
    "    # Bit-wise AND operation between \n",
    "    # unpacked Tensor and 00000011\n",
    "    # [10110001 00101100 00001011 00000010] <- unpacked tensor\n",
    "    # [00000011 00000011 00000011 00000011] <- Mask\n",
    "    # [00000001 00000000 00000011 00000010] <- Result\n",
    "\n",
    "    # Final\n",
    "    # unpacked Tensor state: [00000001 00000000 00000011 00000010]\n",
    "\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "executionInfo": {
     "elapsed": 214,
     "status": "ok",
     "timestamp": 1705676421893,
     "user": {
      "displayName": "Younes Belkada",
      "userId": "15414910276690549281"
     },
     "user_tz": 0
    },
    "height": 590,
    "id": "uEAWcVoONaIy"
   },
   "outputs": [],
   "source": [
    "def unpack_weights(uint8tensor, bits):\n",
    "    num_values = uint8tensor.shape[0] * 8 // bits\n",
    "\n",
    "    num_steps = 8 // bits\n",
    "\n",
    "    unpacked_tensor = torch.zeros((num_values), dtype=torch.uint8)\n",
    "\n",
    "    unpacked_idx = 0\n",
    "\n",
    "    # 1 0 3 2 - 01 00 11 10\n",
    "\n",
    "    # [00000000 00000000 00000000 00000000]\n",
    "    # [10110001 00101100 00001011 00000010]\n",
    "    # [00000001 00000000 00000011 00000010]\n",
    "\n",
    "    # 10110001\n",
    "    # 00000011\n",
    "    \n",
    "    # 00000001\n",
    "\n",
    "    # 1: [10110001]\n",
    "    # 2: [00101100]\n",
    "    # 3: [00001011]\n",
    "\n",
    "    mask = 2 ** bits - 1\n",
    "\n",
    "    for i in range(uint8tensor.shape[0]):\n",
    "        for j in range(num_steps):\n",
    "            unpacked_tensor[unpacked_idx] |= uint8tensor[i] >> (bits * j)\n",
    "            unpacked_idx += 1\n",
    "\n",
    "    unpacked_tensor &= mask\n",
    "    return unpacked_tensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 254,
     "status": "ok",
     "timestamp": 1705676423045,
     "user": {
      "displayName": "Younes Belkada",
      "userId": "15414910276690549281"
     },
     "user_tz": 0
    },
    "height": 46,
    "id": "6_9OAg88NyvF",
    "outputId": "d93e9b97-ad34-4127-dc40-de45daaea24c"
   },
   "outputs": [],
   "source": [
    "unpacked_tensor = torch.tensor([177, 255], \n",
    "                               dtype=torch.uint8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 46,
    "id": "E6EFomfQPD6m"
   },
   "outputs": [],
   "source": [
    "# Answer should be: torch.tensor([1, 0, 3, 2, 3, 3, 3, 3]\n",
    "unpack_weights(unpacked_tensor, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "authorship_tag": "ABX9TyPkr3WbwaCNWYoK5KkH4mJ5",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
